{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e0056726-3359-4a9b-9913-016617525a6d",
   "metadata": {},
   "source": [
    "# Scalable late interaction vectors in Elasticsearch: Bit Vectors #\n",
    "\n",
    "In this notebook, we will be looking at how to convert late interaction vectors to bit vectors to \n",
    "1. Save siginificant disk space  \n",
    "2. Lower query latency\n",
    "   \n",
    "We will also look at how we can use hamming distance to speed our queries up even further.  \n",
    "This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook.  \n",
    " \n",
    "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49dbcc61-5dab-4cf6-bbc5-7fa898707ce6",
   "metadata": {},
   "source": [
    "This is the key part of this notebook. We use the `to_bit_vectors()` function to convert our vectors into bit vectors.  \n",
    "The function is simple in essence. Values `> 0` are converted to `1`, values `< 0` are converted to `0`. We then convert our array of `0`s and `1`s to a hex string, that represents our bit vector.  \n",
    "So don't be surprised that the values that we will be indexing look like strings and not arrays as before. This is intended!  \n",
    "\n",
    "Learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_bit_vectors(embeddings: list) -> list:\n",
    "    return [\n",
    "        np.packbits(np.where(np.array(embedding) > 0, 1, 0))\n",
    "        .astype(np.int8)\n",
    "        .tobytes()\n",
    "        .hex()\n",
    "        for embedding in embeddings\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52b7449b-8fbf-46b7-90c9-330070f6996a",
   "metadata": {},
   "source": [
    "Here we are defining our mapping for our Elasticsearch index. Note how we set the `element_type` parameter to `bit` to inform Elasticsearch that we will be indexing bit vectors in this field. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Creating index: searchlabs-colpali-hamming\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "load_dotenv(\"elastic.env\")\n",
    "\n",
    "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
    "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
    "INDEX_NAME = \"searchlabs-colpali-hamming\"\n",
    "\n",
    "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
    "\n",
    "mappings = {\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"col_pali_vectors\": {\"type\": \"rank_vectors\", \"element_type\": \"bit\"}\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "if not es.indices.exists(index=INDEX_NAME):\n",
    "    print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
    "    es.indices.create(index=INDEX_NAME, body=mappings)\n",
    "else:\n",
    "    print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
    "\n",
    "\n",
    "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
    "    for attempt in range(1, retries + 1):\n",
    "        try:\n",
    "            return es_client.index(index=index, id=doc_id, document=document)\n",
    "        except Exception as e:\n",
    "            if attempt < retries:\n",
    "                wait_time = initial_backoff * (2 ** (attempt - 1))\n",
    "                print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
    "                time.sleep(wait_time)\n",
    "            else:\n",
    "                print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
    "                raise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8be1b809674143c486705f1699f440dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Indexing documents:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completed indexing 500 documents\n"
     ]
    }
   ],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "\n",
    "\n",
    "def process_file(file_name, vectors):\n",
    "    if es.exists(index=INDEX_NAME, id=file_name):\n",
    "        return\n",
    "\n",
    "    bit_vectors = to_bit_vectors(vectors)\n",
    "\n",
    "    index_document(\n",
    "        es_client=es,\n",
    "        index=INDEX_NAME,\n",
    "        doc_id=file_name,\n",
    "        document={\"col_pali_vectors\": bit_vectors},\n",
    "    )\n",
    "\n",
    "\n",
    "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
    "    file_to_multi_vectors = pickle.load(f)\n",
    "\n",
    "with ThreadPoolExecutor(max_workers=10) as executor:\n",
    "    list(\n",
    "        tqdm(\n",
    "            executor.map(\n",
    "                lambda item: process_file(*item), file_to_multi_vectors.items()\n",
    "            ),\n",
    "            total=len(file_to_multi_vectors),\n",
    "            desc=\"Indexing documents\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1dfc3713-d649-46db-aa81-171d6d92668e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "63a206f0fc2d491196b83b450eb4b93a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from colpali_engine.models import ColPali, ColPaliProcessor\n",
    "\n",
    "model_name = \"vidore/colpali-v1.3\"\n",
    "model = ColPali.from_pretrained(\n",
    "    \"vidore/colpali-v1.3\",\n",
    "    torch_dtype=torch.float32,\n",
    "    device_map=\"mps\",  # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
    ").eval()\n",
    "\n",
    "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "def create_col_pali_query_vectors(query: str) -> list:\n",
    "    queries = col_pali_processor.process_queries([query]).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        return model(**queries).tolist()[0]"
   ]
  },
  {
   "cell_type": "raw",
   "id": "5e86697d-d9dd-4224-85c8-023c71c88548",
   "metadata": {},
   "source": [
    "Here we run the search against our index comparing our query vector converted to bit vectors to the bit vectors in our index.  \n",
    "Trading of a bit of accuracy, this is allows us to use hamming distance (`maxSimInvHamming(...)`), which is able to leverage optimzations such as bit-masks, SIMD, etc. Again - learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. \n",
    "\n",
    "See the cell below about a different technique to query our bit vectors. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"_source\": false, \"query\": {\"script_score\": {\"query\": {\"match_all\": {}}, \"script\": {\"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", \"params\": {\"query_vector\": [\"7747bcd9732859c3645aa81036f5c960\", \"729b3c418ba8594a67daa042eca1c961\", \"609e3d8a2ac379c2204aa0cfa8345bdc\", \"30bf378a2ac279da245aa8dfa83c3bdc\", \"64af77ea2acdf9c28c0aa5df863677f4\", \"686f3fce2ac871c26e6aaddf023455ec\", \"383f31a8e8c0f8ca2c4ab54f047c7dec\", \"203b33caaac279da0acaa54f8a3c6bcc\", \"319a63eba8d279ca30dbbccf8f757b8e\", \"203b73ca28d2798a325bb44f8c3c5bce\", \"203bb7caa8d2718a1a4bb14f8a3c5bdc\", \"203bb7caa8d2798a1a6aa14f8a3c5fdc\", \"303b33caa8d2798a0a4aa14f8a3c5bdc\", \"303b33caaad379ca0e4aa14f8a3c5bdc\", \"709b33caaac379ca0c4aa14f8a3c5fdc\", \"708e37eaaac779ca2c4aa1df863c1fdc\", \"648e77ea6acd79caac4ae1df86363ffc\", \"648e77ea6acdf9caac4ae5df06363ffc\", \"608f37ea2ac579ca2c4ea1df063c3ffc\", \"709f37c8aac379ca2c4ea1df863c1fdc\", \"70af31c82ac671ce2c6ab14fc43c1bfc\"]}}}}, \"size\": 5}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_110.jpg\" alt=\"image_110.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display, HTML\n",
    "import os\n",
    "import json\n",
    "\n",
    "DOCUMENT_DIR = \"searchlabs-colpali\"\n",
    "\n",
    "query = \"What do companies use for recruiting?\"\n",
    "query_vector = to_bit_vectors(create_col_pali_query_vectors(query))\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"query\": {\n",
    "        \"script_score\": {\n",
    "            \"query\": {\"match_all\": {}},\n",
    "            \"script\": {\n",
    "                \"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\",\n",
    "                \"params\": {\"query_vector\": query_vector},\n",
    "            },\n",
    "        }\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "print(json.dumps(es_query))\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "for image_id in image_ids:\n",
    "    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
    "    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "html += \"</div>\"\n",
    "\n",
    "display(HTML(html))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e27b68ac-bec8-4415-919e-8b916bc35816",
   "metadata": {},
   "source": [
    "Above we have seen how to query our data using the `maxSimInvHamming(...)` function.  \n",
    "We can also just pass the full fidelity col pali vector and use the `maxSimDotProduct(...)` function for [asymmetric similarity](https://www.elastic.co/guide/en/elasticsearch/reference/8.18/rank-vectors.html#rank-vectors-scoring) between the vectors. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "32fd9ee4-d7c6-4954-a766-7b06735290ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "query = \"What do companies use for recruiting?\"\n",
    "query_vector = create_col_pali_query_vectors(query)\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"query\": {\n",
    "        \"script_score\": {\n",
    "            \"query\": {\"match_all\": {}},\n",
    "            \"script\": {\n",
    "                \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
    "                \"params\": {\"query_vector\": query_vector},\n",
    "            },\n",
    "        }\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "for image_id in image_ids:\n",
    "    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
    "    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "html += \"</div>\"\n",
    "\n",
    "display(HTML(html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8df1e3-af66-4e35-9c26-7257c281536f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
    "print(\"Shutting down the kernel to free memory...\")\n",
    "import os\n",
    "\n",
    "os._exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dependecy-test-colpali-blog",
   "language": "python",
   "name": "dependecy-test-colpali-blog"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
